"""
@Project: pythonPro1
@Name: mean_shift7.py
@Author: linxin_liu
@Date: 2022/11/7 21:33
"""
import numpy as np
from matplotlib import pyplot as plt
from sklearn.cluster import MeanShift, estimate_bandwidth
def mean_shift_clustering(x, quantile=0.1, *, xlabel='Annual Income',
                          ylabel='Spending Score', show=True):
    """Cluster points with Mean Shift, print the cluster count and plot them.

    The kernel bandwidth is estimated from the data with
    ``estimate_bandwidth(x, quantile=quantile)``. Points are coloured by
    cluster label and the cluster centres are overlaid.

    Args:
        x: array-like of shape (n_samples, n_features) with n_features >= 2.
            Only the first two columns are plotted.
        quantile: Forwarded to ``estimate_bandwidth`` (0.1 in the original
            script).
        xlabel: Label for the horizontal axis.
        ylabel: Label for the vertical axis.
        show: If True, call ``plt.show()`` at the end.

    Returns:
        The fitted ``MeanShift`` estimator.

    Raises:
        ValueError: If ``x`` is not 2-D with at least two columns, or if the
            estimated bandwidth is not positive (e.g. all points identical).
    """
    x = np.asarray(x)
    if x.ndim != 2 or x.shape[1] < 2:
        raise ValueError(f'x must have shape (n_samples, >=2), got {x.shape}')

    bandwidth = estimate_bandwidth(x, quantile=quantile)
    if bandwidth <= 0:
        # MeanShift rejects a non-positive bandwidth; give a clearer message.
        raise ValueError('estimated bandwidth is 0; the data has too little '
                         'spread for Mean Shift (try a larger quantile)')
    ms = MeanShift(bandwidth=bandwidth).fit(x)

    # labels_ already assigns every training point to its nearest centre, so
    # a second ms.predict(x) pass over the same data is unnecessary.
    labels = ms.labels_
    n_clusters_ = len(np.unique(labels))
    print(f'聚类个数: {n_clusters_}')  # "number of clusters"

    # Apply the style BEFORE creating the figure: most rcParams are read when
    # figures/axes/artists are created, so styling after plotting (as the
    # original did) left the plot largely unstyled. The context manager also
    # avoids changing the global style for later plots.
    with plt.style.context('fivethirtyeight'):
        plt.figure()
        plt.scatter(x[:, 0], x[:, 1], c=labels, cmap='rainbow')
        # Draw centroids after the points so they are not hidden beneath them.
        plt.scatter(ms.cluster_centers_[:, 0], ms.cluster_centers_[:, 1],
                    s=50, c='blue', label='centroid')
        plt.title('Mean Shift Clustering')
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.legend()
        # Pass True explicitly: grid() with no arguments TOGGLES visibility,
        # and this style already turns the grid on, so a bare call hides it.
        plt.grid(True)
        if show:
            plt.show()
    return ms


if __name__ == '__main__':
    # NOTE(review): the original script used `x` without ever defining it, so
    # it stopped with a NameError before fitting anything. The array below is
    # a synthetic placeholder; replace it with the real (n_samples, 2) feature
    # matrix. The axis labels suggest annual-income / spending-score columns,
    # but that is an assumption — TODO confirm.
    _rng = np.random.default_rng(0)
    _centres = np.array([[25.0, 20.0], [55.0, 50.0], [85.0, 80.0]])
    x = np.concatenate(
        [c + _rng.normal(scale=6.0, size=(60, 2)) for c in _centres]
    )
    mean_shift_clustering(x)